#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A PipelineRunner using the SDK harness.
"""
from __future__ import absolute_import
from __future__ import print_function

import collections
import contextlib
import copy
import itertools
import logging
import os
import queue
import subprocess
import sys
import threading
import time
import uuid
from builtins import object
from concurrent import futures

import grpc

import apache_beam as beam  # pylint: disable=ungrouped-imports
from apache_beam import coders
from apache_beam import metrics
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricKey
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.metrics.metricbase import MetricName
from apache_beam.options import pipeline_options
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import beam_provision_api_pb2
from apache_beam.portability.api import beam_provision_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners import runner
from apache_beam.runners.portability import artifact_service
from apache_beam.runners.portability import fn_api_runner_transforms
from apache_beam.runners.portability.fn_api_runner_transforms import create_buffer_id
from apache_beam.runners.portability.fn_api_runner_transforms import only_element
from apache_beam.runners.portability.fn_api_runner_transforms import split_buffer_id
from apache_beam.runners.portability.fn_api_runner_transforms import unique_name
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import sdk_worker
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.transforms import trigger
from apache_beam.transforms.window import GlobalWindows
from apache_beam.utils import profiler
from apache_beam.utils import proto_utils

# This module is experimental. No backwards-compatibility guarantees.

ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder(
    beam.coders.BytesCoder(),
    beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested(
        beam.transforms.window.GlobalWindows.windowed_value(b''))


class ControlConnection(object):

  _uid_counter = 0
  _lock = threading.Lock()

  def __init__(self):
    self._push_queue = queue.Queue()
    self._input = None
    self._futures_by_id = dict()
    self._read_thread = threading.Thread(
        name='beam_control_read', target=self._read)
    self._state = BeamFnControlServicer.UNSTARTED_STATE

  def _read(self):
    for data in self._input:
      self._futures_by_id.pop(data.instruction_id).set(data)

  def push(self, req):
    if req == BeamFnControlServicer._DONE_MARKER:
      self._push_queue.put(req)
      return None
    if not req.instruction_id:
      with ControlConnection._lock:
        ControlConnection._uid_counter += 1
        req.instruction_id = 'control_%s' % ControlConnection._uid_counter
    future = ControlFuture(req.instruction_id)
    self._futures_by_id[req.instruction_id] = future
    self._push_queue.put(req)
    return future

  def get_req(self):
    return self._push_queue.get()

  def set_input(self, input):
    with ControlConnection._lock:
      if self._input:
        raise RuntimeError('input is already set.')
      self._input = input
      self._read_thread.start()
      self._state = BeamFnControlServicer.STARTED_STATE

  def close(self):
    with ControlConnection._lock:
      if self._state == BeamFnControlServicer.STARTED_STATE:
        self.push(BeamFnControlServicer._DONE_MARKER)
        self._read_thread.join()
      self._state = BeamFnControlServicer.DONE_STATE


class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
  """Implementation of BeamFnControlServicer for clients."""

  UNSTARTED_STATE = 'unstarted'
  STARTED_STATE = 'started'
  DONE_STATE = 'done'

  _DONE_MARKER = object()

  def __init__(self):
    self._lock = threading.Lock()
    self._uid_counter = 0
    self._state = self.UNSTARTED_STATE
    # following self._req_* variables are used for debugging purpose, data is
    # added only when self._log_req is True.
    self._req_sent = collections.defaultdict(int)
    self._req_worker_mapping = {}
    self._log_req = logging.getLogger().getEffectiveLevel() <= logging.DEBUG
    self._connections_by_worker_id = collections.defaultdict(ControlConnection)

  def get_conn_by_worker_id(self, worker_id):
    with self._lock:
      return self._connections_by_worker_id[worker_id]

  def Control(self, iterator, context):
    with self._lock:
      if self._state == self.DONE_STATE:
        return
      else:
        self._state = self.STARTED_STATE

    worker_id = dict(context.invocation_metadata()).get('worker_id')
    if not worker_id:
      raise RuntimeError('All workers communicate through gRPC should have '
                         'worker_id. Received None.')

    control_conn = self.get_conn_by_worker_id(worker_id)
    control_conn.set_input(iterator)

    while True:
      to_push = control_conn.get_req()
      if to_push is self._DONE_MARKER:
        return
      yield to_push
      if self._log_req:
        self._req_sent[to_push.instruction_id] += 1

  def done(self):
    self._state = self.DONE_STATE
    logging.debug('Runner: Requests sent by runner: %s',
                  [(str(req), cnt) for req, cnt in self._req_sent.items()])
    logging.debug('Runner: Requests multiplexing info: %s',
                  [(str(req), worker) for req, worker
                   in self._req_worker_mapping.items()])


class _ListBuffer(list):
  """Used to support parititioning of a list."""
  def partition(self, n):
    return [self[k::n] for k in range(n)]


class _GroupingBuffer(object):
  """Used to accumulate groupded (shuffled) results."""
  def __init__(self, pre_grouped_coder, post_grouped_coder, windowing):
    self._key_coder = pre_grouped_coder.key_coder()
    self._pre_grouped_coder = pre_grouped_coder
    self._post_grouped_coder = post_grouped_coder
    self._table = collections.defaultdict(list)
    self._windowing = windowing
    self._grouped_output = None

  def append(self, elements_data):
    if self._grouped_output:
      raise RuntimeError('Grouping table append after read.')
    input_stream = create_InputStream(elements_data)
    coder_impl = self._pre_grouped_coder.get_impl()
    key_coder_impl = self._key_coder.get_impl()
    # TODO(robertwb): We could optimize this even more by using a
    # window-dropping coder for the data plane.
    is_trivial_windowing = self._windowing.is_default()
    while input_stream.size() > 0:
      windowed_key_value = coder_impl.decode_from_stream(input_stream, True)
      key, value = windowed_key_value.value
      self._table[key_coder_impl.encode(key)].append(
          value if is_trivial_windowing
          else windowed_key_value.with_value(value))

  def partition(self, n):
    """ It is used to partition _GroupingBuffer to N parts. Once it is
    partitioned, it would not be re-partitioned with diff N. Re-partition
    is not supported now.
    """
    if not self._grouped_output:
      if self._windowing.is_default():
        globally_window = GlobalWindows.windowed_value(None).with_value
        windowed_key_values = lambda key, values: [
            globally_window((key, values))]
      else:
        # TODO(pabloem, BEAM-7514): Trigger driver needs access to the clock
        #   note that this only comes through if windowing is default - but what
        #   about having multiple firings on the global window.
        #   May need to revise.
        trigger_driver = trigger.create_trigger_driver(self._windowing, True)
        windowed_key_values = trigger_driver.process_entire_key
      coder_impl = self._post_grouped_coder.get_impl()
      key_coder_impl = self._key_coder.get_impl()
      self._grouped_output = [[] for _ in range(n)]
      output_stream_list = []
      for _ in range(n):
        output_stream_list.append(create_OutputStream())
      for idx, (encoded_key, windowed_values) in enumerate(self._table.items()):
        key = key_coder_impl.decode(encoded_key)
        for wkvs in windowed_key_values(key, windowed_values):
          coder_impl.encode_to_stream(wkvs, output_stream_list[idx % n], True)
      for ix, output_stream in enumerate(output_stream_list):
        self._grouped_output[ix] = [output_stream.get()]
      self._table = None
    return self._grouped_output

  def __iter__(self):
    """ Since partition() returns a list of lists, add this __iter__ to return
    a list to simplify code when we need to iterate through ALL elements of
    _GroupingBuffer.
    """
    return itertools.chain(*self.partition(1))


class _WindowGroupingBuffer(object):
  """Used to partition windowed side inputs."""
  def __init__(self, access_pattern, coder):
    # Here's where we would use a different type of partitioning
    # (e.g. also by key) for a different access pattern.
    if access_pattern.urn == common_urns.side_inputs.ITERABLE.urn:
      self._kv_extrator = lambda value: ('', value)
      self._key_coder = coders.SingletonCoder('')
      self._value_coder = coder.wrapped_value_coder
    elif access_pattern.urn == common_urns.side_inputs.MULTIMAP.urn:
      self._kv_extrator = lambda value: value
      self._key_coder = coder.wrapped_value_coder.key_coder()
      self._value_coder = (
          coder.wrapped_value_coder.value_coder())
    else:
      raise ValueError(
          "Unknown access pattern: '%s'" % access_pattern.urn)
    self._windowed_value_coder = coder
    self._window_coder = coder.window_coder
    self._values_by_window = collections.defaultdict(list)

  def append(self, elements_data):
    input_stream = create_InputStream(elements_data)
    while input_stream.size() > 0:
      windowed_value = self._windowed_value_coder.get_impl(
          ).decode_from_stream(input_stream, True)
      key, value = self._kv_extrator(windowed_value.value)
      for window in windowed_value.windows:
        self._values_by_window[key, window].append(value)

  def encoded_items(self):
    value_coder_impl = self._value_coder.get_impl()
    key_coder_impl = self._key_coder.get_impl()
    for (key, window), values in self._values_by_window.items():
      encoded_window = self._window_coder.encode(window)
      encoded_key = key_coder_impl.encode_nested(key)
      output_stream = create_OutputStream()
      for value in values:
        value_coder_impl.encode_to_stream(value, output_stream, True)
      yield encoded_key, encoded_window, output_stream.get()


class FnApiRunner(runner.PipelineRunner):

  def __init__(
      self,
      default_environment=None,
      bundle_repeat=0,
      use_state_iterables=False,
      provision_info=None):
    """Creates a new Fn API Runner.

    Args:
      default_environment: the default environment to use for UserFns.
      bundle_repeat: replay every bundle this many extra times, for profiling
          and debugging
      use_state_iterables: Intentionally split gbk iterables over state API
          (for testing)
      provision_info: provisioning info to make available to workers, or None
    """
    super(FnApiRunner, self).__init__()
    self._last_uid = -1
    self._default_environment = (
        default_environment
        or beam_runner_api_pb2.Environment(urn=python_urns.EMBEDDED_PYTHON))
    self._bundle_repeat = bundle_repeat
    self._num_workers = 1
    self._progress_frequency = None
    self._profiler_factory = None
    self._use_state_iterables = use_state_iterables
    self._provision_info = provision_info

  def _next_uid(self):
    self._last_uid += 1
    return str(self._last_uid)

  def run_pipeline(self, pipeline, options):
    MetricsEnvironment.set_metrics_supported(False)
    RuntimeValueProvider.set_runtime_options({})

    # Setup "beam_fn_api" experiment options if lacked.
    experiments = (options.view_as(pipeline_options.DebugOptions).experiments
                   or [])
    if not 'beam_fn_api' in experiments:
      experiments.append('beam_fn_api')
    options.view_as(pipeline_options.DebugOptions).experiments = experiments

    # This is sometimes needed if type checking is disabled
    # to enforce that the inputs (and outputs) of GroupByKey operations
    # are known to be KVs.
    from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
    # TODO: Move group_by_key_input_visitor() to a non-dataflow specific file.
    pipeline.visit(DataflowRunner.group_by_key_input_visitor())
    self._bundle_repeat = self._bundle_repeat or options.view_as(
        pipeline_options.DirectOptions).direct_runner_bundle_repeat
    self._num_workers = options.view_as(
        pipeline_options.DirectOptions).direct_num_workers or self._num_workers
    self._profiler_factory = profiler.Profile.factory_from_options(
        options.view_as(pipeline_options.ProfilingOptions))

    if 'use_sdf_bounded_source' in experiments:
      pipeline.replace_all(DataflowRunner._SDF_PTRANSFORM_OVERRIDES)

    self._latest_run_result = self.run_via_runner_api(pipeline.to_runner_api(
        default_environment=self._default_environment))
    return self._latest_run_result

  def run_via_runner_api(self, pipeline_proto):
    stage_context, stages = self.create_stages(pipeline_proto)
    # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access to
    #   the teststream (if any), and all the stages).
    return self.run_stages(stage_context, stages)

  @contextlib.contextmanager
  def maybe_profile(self):
    if self._profiler_factory:
      try:
        profile_id = 'direct-' + subprocess.check_output(
            ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
        ).decode(errors='ignore').strip()
      except subprocess.CalledProcessError:
        profile_id = 'direct-unknown'
      profiler = self._profiler_factory(profile_id, time_prefix='')
    else:
      profiler = None

    if profiler:
      with profiler:
        yield
      if not self._bundle_repeat:
        logging.warning(
            'The --direct_runner_bundle_repeat option is not set; '
            'a significant portion of the profile may be one-time overhead.')
      path = profiler.profile_output
      print('CPU Profile written to %s' % path)
      try:
        import gprof2dot  # pylint: disable=unused-variable
        if not subprocess.call([
            sys.executable, '-m', 'gprof2dot',
            '-f', 'pstats', path, '-o', path + '.dot']):
          if not subprocess.call(
              ['dot', '-Tsvg', '-o', path + '.svg', path + '.dot']):
            print('CPU Profile rendering at file://%s.svg'
                  % os.path.abspath(path))
      except ImportError:
        # pylint: disable=superfluous-parens
        print('Please install gprof2dot and dot for profile renderings.')

    else:
      # Empty context.
      yield

  def create_stages(self, pipeline_proto):
    return fn_api_runner_transforms.create_and_optimize_stages(
        copy.deepcopy(pipeline_proto),
        phases=[fn_api_runner_transforms.annotate_downstream_side_inputs,
                fn_api_runner_transforms.fix_side_input_pcoll_coders,
                fn_api_runner_transforms.lift_combiners,
                fn_api_runner_transforms.expand_sdf,
                fn_api_runner_transforms.expand_gbk,
                fn_api_runner_transforms.sink_flattens,
                fn_api_runner_transforms.greedily_fuse,
                fn_api_runner_transforms.read_to_impulse,
                fn_api_runner_transforms.impulse_to_input,
                fn_api_runner_transforms.inject_timer_pcollections,
                fn_api_runner_transforms.sort_stages,
                fn_api_runner_transforms.window_pcollection_coders],
        known_runner_urns=frozenset([
            common_urns.primitives.FLATTEN.urn,
            common_urns.primitives.GROUP_BY_KEY.urn]),
        use_state_iterables=self._use_state_iterables)

  def run_stages(self, stage_context, stages):
    """Run a list of topologically-sorted stages in batch mode.

    Args:
      stage_context (fn_api_runner_transforms.TransformContext)
      stages (list[fn_api_runner_transforms.Stage])
    """
    worker_handler_manager = WorkerHandlerManager(
        stage_context.components.environments, self._provision_info)
    metrics_by_stage = {}
    monitoring_infos_by_stage = {}

    try:
      with self.maybe_profile():
        pcoll_buffers = collections.defaultdict(_ListBuffer)
        for stage in stages:
          stage_results = self._run_stage(
              worker_handler_manager.get_worker_handlers,
              stage_context.components,
              stage,
              pcoll_buffers,
              stage_context.safe_coders)
          metrics_by_stage[stage.name] = stage_results.process_bundle.metrics
          monitoring_infos_by_stage[stage.name] = (
              stage_results.process_bundle.monitoring_infos)
    finally:
      worker_handler_manager.close_all()
    return RunnerResult(
        runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage)

  def _store_side_inputs_in_state(self,
                                  worker_handler,
                                  context,
                                  pipeline_components,
                                  data_side_input,
                                  pcoll_buffers,
                                  safe_coders):
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = context.coders[safe_coders[
          pipeline_components.pcollections[pcoll_id].coder_id]]
      elements_by_window = _WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        worker_handler.state.blocking_append(state_key, elements_data)

  def _run_bundle_multiple_times_for_testing(
      self, worker_handler_list, process_bundle_descriptor, data_input,
      data_output, get_input_coder_callable):

    # all workers share state, so use any worker_handler.
    worker_handler = worker_handler_list[0]
    for k in range(self._bundle_repeat):
      try:
        worker_handler.state.checkpoint()
        ParallelBundleManager(
            worker_handler_list, lambda pcoll_id: [],
            get_input_coder_callable, process_bundle_descriptor,
            self._progress_frequency, k, num_workers=self._num_workers
        ).process_bundle(data_input, data_output)
      finally:
        worker_handler.state.restore()

  def _collect_written_timers_and_add_to_deferred_inputs(self,
                                                         context,
                                                         pipeline_components,
                                                         stage,
                                                         get_buffer_callable,
                                                         deferred_inputs):

    for transform_id, timer_writes in stage.timer_pcollections:

      # Queue any set timers as new inputs.
      windowed_timer_coder_impl = context.coders[
          pipeline_components.pcollections[timer_writes].coder_id].get_impl()
      written_timers = get_buffer_callable(
          create_buffer_id(timer_writes, kind='timers'))
      if written_timers:
        # Keep only the "last" timer set per key and window.
        timers_by_key_and_window = {}
        for elements_data in written_timers:
          input_stream = create_InputStream(elements_data)
          while input_stream.size() > 0:
            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                input_stream, True)
            key, _ = windowed_key_timer.value
            # TODO: Explode and merge windows.
            assert len(windowed_key_timer.windows) == 1
            timers_by_key_and_window[
                key, windowed_key_timer.windows[0]] = windowed_key_timer
        out = create_OutputStream()
        for windowed_key_timer in timers_by_key_and_window.values():
          windowed_timer_coder_impl.encode_to_stream(
              windowed_key_timer, out, True)
        deferred_inputs[transform_id] = _ListBuffer([out.get()])
        written_timers[:] = []

  def _add_residuals_and_channel_splits_to_deferred_inputs(
      self, splits, get_input_coder_callable,
      input_for_callable, last_sent, deferred_inputs):
    prev_stops = {}
    for split in splits:
      for delayed_application in split.residual_roots:
        deferred_inputs[
            input_for_callable(
                delayed_application.application.ptransform_id,
                delayed_application.application.input_id)
        ].append(delayed_application.application.element)
      for channel_split in split.channel_splits:
        coder_impl = get_input_coder_callable(channel_split.ptransform_id)
        # TODO(SDF): This requires determanistic ordering of buffer iteration.
        # TODO(SDF): The return split is in terms of indices.  Ideally,
        # a runner could map these back to actual positions to effectively
        # describe the two "halves" of the now-split range.  Even if we have
        # to buffer each element we send (or at the very least a bit of
        # metadata, like position, about each of them) this should be doable
        # if they're already in memory and we are bounding the buffer size
        # (e.g. to 10mb plus whatever is eagerly read from the SDK).  In the
        # case of non-split-points, we can either immediately replay the
        # "non-split-position" elements or record them as we do the other
        # delayed applications.

        # Decode and recode to split the encoded buffer by element index.
        all_elements = list(coder_impl.decode_all(b''.join(last_sent[
            channel_split.ptransform_id])))
        residual_elements = all_elements[
            channel_split.first_residual_element : prev_stops.get(
                channel_split.ptransform_id, len(all_elements)) + 1]
        if residual_elements:
          deferred_inputs[channel_split.ptransform_id].append(
              coder_impl.encode_all(residual_elements))
        prev_stops[
            channel_split.ptransform_id] = channel_split.last_primary_element

  @staticmethod
  def _extract_stage_data_endpoints(
      stage, pipeline_components, data_api_service_descriptor, pcoll_buffers):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data ApiServiceDescriptor.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
            data_input[target] = _ListBuffer([ENCODED_IMPULSE_VALUE])
          else:
            data_input[target] = pcoll_buffers[pcoll_id]
          coder_id = pipeline_components.pcollections[
              only_element(transform.outputs.values())].coder_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
          coder_id = pipeline_components.pcollections[
              only_element(transform.inputs.values())].coder_id
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
    return data_input, data_side_input, data_output

  def _run_stage(self,
                 worker_handler_factory,
                 pipeline_components,
                 stage,
                 pcoll_buffers,
                 safe_coders):
    """Run an individual stage.

    Args:
      worker_handler_factory: A ``callable`` that takes in an environment, and
        returns a ``WorkerHandler`` class.
      pipeline_components (beam_runner_api_pb2.Components): TODO
      stage (fn_api_runner_transforms.Stage)
      pcoll_buffers (collections.defaultdict of str: list): Mapping of
        PCollection IDs to list that functions as buffer for the
        ``beam.PCollection``.
      safe_coders (dict): TODO
    """
    def iterable_state_write(values, element_coder_impl):
      token = unique_name(None, 'iter').encode('ascii')
      out = create_OutputStream()
      for element in values:
        element_coder_impl.encode_to_stream(element, out, True)
      worker_handler.state.blocking_append(
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          out.get())
      return token

    worker_handler_list = worker_handler_factory(
        stage.environment, self._num_workers)

    # All worker_handlers share the same grpc server, so we can read grpc server
    # info from any worker_handler and read from the first worker_handler.
    worker_handler = next(iter(worker_handler_list))
    context = pipeline_context.PipelineContext(
        pipeline_components, iterable_state_write=iterable_state_write)
    data_api_service_descriptor = worker_handler.data_api_service_descriptor()

    logging.info('Running %s', stage.name)
    data_input, data_side_input, data_output = self._extract_endpoints(
        stage, pipeline_components, data_api_service_descriptor, pcoll_buffers)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    if worker_handler.state_api_service_descriptor():
      process_bundle_descriptor.state_api_service_descriptor.url = (
          worker_handler.state_api_service_descriptor().url)

    # Store the required side inputs into state so it is accessible for the
    # worker when it runs this bundle.
    self._store_side_inputs_in_state(worker_handler,
                                     context,
                                     pipeline_components,
                                     data_side_input,
                                     pcoll_buffers,
                                     safe_coders)

    def get_buffer(buffer_id):
      """Returns the buffer for a given (operation_type, PCollection ID).

      For grouping-typed operations, we produce a ``_GroupingBuffer``. For
      others, we produce a ``_ListBuffer``.
      """
      kind, name = split_buffer_id(buffer_id)
      if kind in ('materialize', 'timers'):
        # If `buffer_id` is not a key in `pcoll_buffers`, it will be added by
        # the `defaultdict`.
        return pcoll_buffers[buffer_id]
      elif kind == 'group':
        # This is a grouping write, create a grouping buffer if needed.
        if buffer_id not in pcoll_buffers:
          original_gbk_transform = name
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(list(transform_proto.inputs.values()))
          output_pcoll = only_element(list(transform_proto.outputs.values()))
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[buffer_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(buffer_id)
      return pcoll_buffers[buffer_id]

    def get_input_coder_impl(transform_id):
      return context.coders[safe_coders[
          beam_fn_api_pb2.RemoteGrpcPort.FromString(
              process_bundle_descriptor.transforms[transform_id].spec.payload
          ).coder_id
      ]].get_impl()

    self._run_bundle_multiple_times_for_testing(worker_handler_list,
                                                process_bundle_descriptor,
                                                data_input,
                                                data_output,
                                                get_input_coder_impl)

    bundle_manager = ParallelBundleManager(
        worker_handler_list, get_buffer, get_input_coder_impl,
        process_bundle_descriptor, self._progress_frequency,
        num_workers=self._num_workers)

    result, splits = bundle_manager.process_bundle(data_input, data_output)

    def input_for(ptransform_id, input_id):
      input_pcoll = process_bundle_descriptor.transforms[
          ptransform_id].inputs[input_id]
      for read_id, proto in process_bundle_descriptor.transforms.items():
        if (proto.spec.urn == bundle_processor.DATA_INPUT_URN
            and input_pcoll in proto.outputs.values()):
          return read_id
      raise RuntimeError(
          'No IO transform feeds %s' % ptransform_id)

    last_result = result
    last_sent = data_input

    while True:
      deferred_inputs = collections.defaultdict(_ListBuffer)

      self._collect_written_timers_and_add_to_deferred_inputs(
          context, pipeline_components, stage, get_buffer, deferred_inputs)

      # Queue any process-initiated delayed bundle applications.
      for delayed_application in last_result.process_bundle.residual_roots:
        deferred_inputs[
            input_for(
                delayed_application.application.ptransform_id,
                delayed_application.application.input_id)
        ].append(delayed_application.application.element)

      # Queue any runner-initiated delayed bundle applications.
      self._add_residuals_and_channel_splits_to_deferred_inputs(
          splits, get_input_coder_impl, input_for, last_sent, deferred_inputs)

      if deferred_inputs:
        # The worker will be waiting on these inputs as well.
        for other_input in data_input:
          if other_input not in deferred_inputs:
            deferred_inputs[other_input] = _ListBuffer([])
        # TODO(robertwb): merge results
        # We cannot split deferred_input until we include residual_roots to
        # merged results. Without residual_roots, pipeline stops earlier and we
        # may miss some data.
        bundle_manager._num_workers = 1
        bundle_manager._skip_registration = True
        last_result, splits = bundle_manager.process_bundle(
            deferred_inputs, data_output)
        last_sent = deferred_inputs
        result = beam_fn_api_pb2.InstructionResponse(
            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                monitoring_infos=monitoring_infos.consolidate(
                    itertools.chain(
                        result.process_bundle.monitoring_infos,
                        last_result.process_bundle.monitoring_infos))),
            error=result.error or last_result.error)
      else:
        break

    return result

  @staticmethod
  def _extract_endpoints(stage,
                         pipeline_components,
                         data_api_service_descriptor,
                         pcoll_buffers):
    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Args:
      stage (fn_api_runner_transforms.Stage): The stage to extract endpoints
        for.
      pipeline_components (beam_runner_api_pb2.Components): Components of the
        pipeline to include coders, transforms, PCollections, etc.
      data_api_service_descriptor: A GRPC endpoint descriptor for data plane.
      pcoll_buffers (dict): A dictionary containing buffers for PCollection
        elements.
    Returns:
      A tuple of (data_input, data_side_input, data_output) dictionaries.
        `data_input` is a dictionary mapping (transform_name, output_name) to a
        PCollection buffer; `data_output` is a dictionary mapping
        (transform_name, output_name) to a PCollection ID.
    """
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
            data_input[transform.unique_name] = _ListBuffer(
                [ENCODED_IMPULSE_VALUE])
          else:
            data_input[transform.unique_name] = pcoll_buffers[pcoll_id]
          coder_id = pipeline_components.pcollections[
              only_element(transform.outputs.values())].coder_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = pipeline_components.pcollections[
              only_element(transform.inputs.values())].coder_id
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
    return data_input, data_side_input, data_output

  # These classes are used to interact with the worker.

  class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):

    class CopyOnWriteState(object):
      def __init__(self, underlying):
        self._underlying = underlying
        self._overlay = {}

      def __getitem__(self, key):
        if key in self._overlay:
          return self._overlay[key]
        else:
          return FnApiRunner.StateServicer.CopyOnWriteList(
              self._underlying, self._overlay, key)

      def __delitem__(self, key):
        self._overlay[key] = []

      def commit(self):
        self._underlying.update(self._overlay)
        return self._underlying

    class CopyOnWriteList(object):
      def __init__(self, underlying, overlay, key):
        self._underlying = underlying
        self._overlay = overlay
        self._key = key

      def __iter__(self):
        if self._key in self._overlay:
          return iter(self._overlay[self._key])
        else:
          return iter(self._underlying[self._key])

      def append(self, item):
        if self._key not in self._overlay:
          self._overlay[self._key] = list(self._underlying[self._key])
        self._overlay[self._key].append(item)

    def __init__(self):
      self._lock = threading.Lock()
      self._state = collections.defaultdict(list)
      self._checkpoint = None
      self._use_continuation_tokens = False
      self._continuations = {}

    def checkpoint(self):
      assert self._checkpoint is None
      self._checkpoint = self._state
      self._state = FnApiRunner.StateServicer.CopyOnWriteState(self._state)

    def commit(self):
      self._state.commit()
      self._state = self._checkpoint.commit()
      self._checkpoint = None

    def restore(self):
      self._state = self._checkpoint
      self._checkpoint = None

    @contextlib.contextmanager
    def process_instruction_id(self, unused_instruction_id):
      yield

    def blocking_get(self, state_key, continuation_token=None):
      with self._lock:
        full_state = self._state[self._to_key(state_key)]
        if self._use_continuation_tokens:
          # The token is "nonce:index".
          if not continuation_token:
            token_base = 'token_%x' % len(self._continuations)
            self._continuations[token_base] = tuple(full_state)
            return b'', '%s:0' % token_base
          else:
            token_base, index = continuation_token.split(':')
            ix = int(index)
            full_state = self._continuations[token_base]
            if ix == len(full_state):
              return b'', None
            else:
              return full_state[ix], '%s:%d' % (token_base, ix + 1)
        else:
          assert not continuation_token
          return b''.join(full_state), None

    def blocking_append(self, state_key, data):
      with self._lock:
        self._state[self._to_key(state_key)].append(data)

    def blocking_clear(self, state_key):
      with self._lock:
        del self._state[self._to_key(state_key)]

    @staticmethod
    def _to_key(state_key):
      return state_key.SerializeToString()

  class GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):
    def __init__(self, state):
      self._state = state

    def State(self, request_stream, context=None):
      # Note that this eagerly mutates state, assuming any failures are fatal.
      # Thus it is safe to ignore instruction_reference.
      for request in request_stream:
        request_type = request.WhichOneof('request')
        if request_type == 'get':
          data, continuation_token = self._state.blocking_get(
              request.state_key, request.get.continuation_token)
          yield beam_fn_api_pb2.StateResponse(
              id=request.id,
              get=beam_fn_api_pb2.StateGetResponse(
                  data=data, continuation_token=continuation_token))
        elif request_type == 'append':
          self._state.blocking_append(request.state_key, request.append.data)
          yield beam_fn_api_pb2.StateResponse(
              id=request.id,
              append=beam_fn_api_pb2.StateAppendResponse())
        elif request_type == 'clear':
          self._state.blocking_clear(request.state_key)
          yield beam_fn_api_pb2.StateResponse(
              id=request.id,
              clear=beam_fn_api_pb2.StateClearResponse())
        else:
          raise NotImplementedError('Unknown state request: %s' % request_type)

  class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
    """A singleton cache for a StateServicer."""

    def __init__(self, state_handler):
      self._state_handler = state_handler

    def create_state_handler(self, api_service_descriptor):
      """Returns the singleton state handler."""
      return self._state_handler

    def close(self):
      """Does nothing."""
      pass


class WorkerHandler(object):
  """worker_handler for a worker.

  It provides utilities to start / stop the worker, provision any resources for
  it, as well as provide descriptors for the data, state and logging APIs for
  it.
  """

  _registered_environments = {}
  _worker_id_counter = -1
  _lock = threading.Lock()

  def __init__(
      self, control_handler, data_plane_handler, state, provision_info):
    """Initialize a WorkerHandler.

    Args:
      control_handler:
      data_plane_handler (data_plane.DataChannel):
      state:
      provision_info:
    """
    self.control_handler = control_handler
    self.data_plane_handler = data_plane_handler
    self.state = state
    self.provision_info = provision_info

    with WorkerHandler._lock:
      WorkerHandler._worker_id_counter += 1
      self.worker_id = 'worker_%s' % WorkerHandler._worker_id_counter

  def close(self):
    self.stop_worker()

  def start_worker(self):
    raise NotImplementedError

  def stop_worker(self):
    raise NotImplementedError

  def data_api_service_descriptor(self):
    raise NotImplementedError

  def state_api_service_descriptor(self):
    raise NotImplementedError

  def logging_api_service_descriptor(self):
    raise NotImplementedError

  @classmethod
  def register_environment(cls, urn, payload_type):
    def wrapper(constructor):
      cls._registered_environments[urn] = constructor, payload_type
      return constructor
    return wrapper

  @classmethod
  def create(cls, environment, state, provision_info, grpc_server):
    constructor, payload_type = cls._registered_environments[environment.urn]
    return constructor(
        proto_utils.parse_Bytes(environment.payload, payload_type),
        state,
        provision_info,
        grpc_server)


@WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON, None)
class EmbeddedWorkerHandler(WorkerHandler):
  """An in-memory worker_handler for fn API control, state and data planes."""

  def __init__(self, unused_payload, state, provision_info,
               unused_grpc_server=None):
    super(EmbeddedWorkerHandler, self).__init__(
        self, data_plane.InMemoryDataChannel(), state, provision_info)
    self.control_conn = self
    self.data_conn = self.data_plane_handler

    self.worker = sdk_worker.SdkWorker(
        sdk_worker.BundleProcessorCache(
            FnApiRunner.SingletonStateHandlerFactory(self.state),
            data_plane.InMemoryDataChannelFactory(
                self.data_plane_handler.inverse()),
            {}))
    self._uid_counter = 0

  def push(self, request):
    if not request.instruction_id:
      self._uid_counter += 1
      request.instruction_id = 'control_%s' % self._uid_counter
    response = self.worker.do_instruction(request)
    return ControlFuture(request.instruction_id, response)

  def start_worker(self):
    pass

  def stop_worker(self):
    self.worker.stop()

  def done(self):
    pass

  def data_api_service_descriptor(self):
    return None

  def state_api_service_descriptor(self):
    return None

  def logging_api_service_descriptor(self):
    return None


class BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):

  LOG_LEVEL_MAP = {
      beam_fn_api_pb2.LogEntry.Severity.CRITICAL: logging.CRITICAL,
      beam_fn_api_pb2.LogEntry.Severity.ERROR: logging.ERROR,
      beam_fn_api_pb2.LogEntry.Severity.WARN: logging.WARNING,
      beam_fn_api_pb2.LogEntry.Severity.NOTICE: logging.INFO + 1,
      beam_fn_api_pb2.LogEntry.Severity.INFO: logging.INFO,
      beam_fn_api_pb2.LogEntry.Severity.DEBUG: logging.DEBUG,
      beam_fn_api_pb2.LogEntry.Severity.TRACE: logging.DEBUG - 1,
      beam_fn_api_pb2.LogEntry.Severity.UNSPECIFIED: logging.NOTSET,
  }

  def Logging(self, log_messages, context=None):
    yield beam_fn_api_pb2.LogControl()
    for log_message in log_messages:
      for log in log_message.log_entries:
        logging.log(self.LOG_LEVEL_MAP[log.severity], str(log))


class BasicProvisionService(
    beam_provision_api_pb2_grpc.ProvisionServiceServicer):

  def __init__(self, info):
    self._info = info

  def GetProvisionInfo(self, request, context=None):
    return beam_provision_api_pb2.GetProvisionInfoResponse(
        info=self._info)


class GrpcServer(object):

  _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5

  def __init__(self, state, provision_info):
    self.state = state
    self.provision_info = provision_info
    self.control_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10))
    self.control_port = self.control_server.add_insecure_port('[::]:0')
    self.control_address = 'localhost:%s' % self.control_port

    # Options to have no limits (-1) on the size of the messages
    # received or sent over the data plane. The actual buffer size
    # is controlled in a layer above.
    no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                            ("grpc.max_send_message_length", -1)]
    self.data_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=no_max_message_sizes)
    self.data_port = self.data_server.add_insecure_port('[::]:0')

    self.state_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=no_max_message_sizes)
    self.state_port = self.state_server.add_insecure_port('[::]:0')

    self.control_handler = BeamFnControlServicer()
    beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
        self.control_handler, self.control_server)

    # If we have provision info, serve these off the control port as well.
    if self.provision_info:
      if self.provision_info.provision_info:
        provision_info = self.provision_info.provision_info
        if not provision_info.worker_id:
          provision_info = copy.copy(provision_info)
          provision_info.worker_id = str(uuid.uuid4())
        beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server(
            BasicProvisionService(self.provision_info.provision_info),
            self.control_server)

      if self.provision_info.artifact_staging_dir:
        m = beam_artifact_api_pb2_grpc
        m.add_ArtifactRetrievalServiceServicer_to_server(
            artifact_service.BeamFilesystemArtifactService(
                self.provision_info.artifact_staging_dir),
            self.control_server)

    self.data_plane_handler = data_plane.BeamFnDataServicer()
    beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
        self.data_plane_handler, self.data_server)

    beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
        FnApiRunner.GrpcStateServicer(state),
        self.state_server)

    self.logging_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=2),
        options=no_max_message_sizes)
    self.logging_port = self.logging_server.add_insecure_port('[::]:0')
    beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
        BasicLoggingService(),
        self.logging_server)

    logging.info('starting control server on port %s', self.control_port)
    logging.info('starting data server on port %s', self.data_port)
    logging.info('starting state server on port %s', self.state_port)
    logging.info('starting logging server on port %s', self.logging_port)
    self.logging_server.start()
    self.state_server.start()
    self.data_server.start()
    self.control_server.start()

  def close(self):
    self.control_handler.done()
    to_wait = [
        self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS),
        self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS),
        self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS),
        self.logging_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS)
    ]
    for w in to_wait:
      w.wait()


class GrpcWorkerHandler(WorkerHandler):
  """An grpc based worker_handler for fn API control, state and data planes."""

  def __init__(self, state, provision_info, grpc_server):
    self._grpc_server = grpc_server
    super(GrpcWorkerHandler, self).__init__(
        self._grpc_server.control_handler, self._grpc_server.data_plane_handler,
        state, provision_info)
    self.state = state

    self.control_address = self._grpc_server.control_address
    self.control_conn = self._grpc_server.control_handler \
      .get_conn_by_worker_id(self.worker_id)

    self.data_conn = self._grpc_server.data_plane_handler \
      .get_conn_by_worker_id(self.worker_id)

  def data_api_service_descriptor(self):
    return endpoints_pb2.ApiServiceDescriptor(
        url='localhost:%s' % self._grpc_server.data_port)

  def state_api_service_descriptor(self):
    return endpoints_pb2.ApiServiceDescriptor(
        url='localhost:%s' % self._grpc_server.state_port)

  def logging_api_service_descriptor(self):
    return endpoints_pb2.ApiServiceDescriptor(
        url='localhost:%s' % self._grpc_server.logging_port)

  def close(self):
    self.control_conn.close()
    self.data_conn.close()
    super(GrpcWorkerHandler, self).close()


@WorkerHandler.register_environment(
    common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload)
class ExternalWorkerHandler(GrpcWorkerHandler):
  def __init__(self, external_payload, state, provision_info, grpc_server):
    super(ExternalWorkerHandler, self).__init__(state, provision_info,
                                                grpc_server)
    self._external_payload = external_payload

  def start_worker(self):
    stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub(
        GRPCChannelFactory.insecure_channel(
            self._external_payload.endpoint.url))
    response = stub.NotifyRunnerAvailable(
        beam_fn_api_pb2.NotifyRunnerAvailableRequest(
            worker_id=self.worker_id,
            control_endpoint=endpoints_pb2.ApiServiceDescriptor(
                url=self.control_address),
            logging_endpoint=self.logging_api_service_descriptor(),
            params=self._external_payload.params))
    if response.error:
      raise RuntimeError("Error starting worker: %s" % response.error)

  def stop_worker(self):
    pass


@WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON_GRPC, bytes)
class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
  def __init__(self, num_workers_payload, state, provision_info, grpc_server):
    super(EmbeddedGrpcWorkerHandler, self).__init__(state, provision_info,
                                                    grpc_server)
    self._num_threads = int(num_workers_payload) if num_workers_payload else 1

  def start_worker(self):
    self.worker = sdk_worker.SdkHarness(
        self.control_address, worker_count=self._num_threads,
        worker_id=self.worker_id)
    self.worker_thread = threading.Thread(
        name='run_worker', target=self.worker.run)
    self.worker_thread.daemon = True
    self.worker_thread.start()

  def stop_worker(self):
    self.worker_thread.join()


@WorkerHandler.register_environment(python_urns.SUBPROCESS_SDK, bytes)
class SubprocessSdkWorkerHandler(GrpcWorkerHandler):
  def __init__(self, worker_command_line, state, provision_info, grpc_server):
    super(SubprocessSdkWorkerHandler, self).__init__(state, provision_info,
                                                     grpc_server)
    self._worker_command_line = worker_command_line

  def start_worker(self):
    from apache_beam.runners.portability import local_job_service
    self.worker = local_job_service.SubprocessSdkWorker(
        self._worker_command_line, self.control_address, self.worker_id)
    self.worker_thread = threading.Thread(
        name='run_worker', target=self.worker.run)
    self.worker_thread.start()

  def stop_worker(self):
    self.worker_thread.join()


@WorkerHandler.register_environment(common_urns.environments.DOCKER.urn,
                                    beam_runner_api_pb2.DockerPayload)
class DockerSdkWorkerHandler(GrpcWorkerHandler):
  def __init__(self, payload, state, provision_info, grpc_server):
    super(DockerSdkWorkerHandler, self).__init__(state, provision_info,
                                                 grpc_server)
    self._container_image = payload.container_image
    self._container_id = None

  def start_worker(self):
    try:
      subprocess.check_call(['docker', 'pull', self._container_image])
    except Exception:
      logging.info('Unable to pull image %s' % self._container_image)
    self._container_id = subprocess.check_output(
        ['docker',
         'run',
         '-d',
         # TODO:  credentials
         '--network=host',
         self._container_image,
         '--id=%s' % uuid.uuid4(),
         '--logging_endpoint=%s' % self.logging_api_service_descriptor().url,
         '--control_endpoint=%s' % self.control_address,
         '--artifact_endpoint=%s' % self.control_address,
         '--provision_endpoint=%s' % self.control_address,
        ]).strip()
    while True:
      logging.info('Waiting for docker to start up...')
      status = subprocess.check_output([
          'docker',
          'inspect',
          '-f',
          '{{.State.Status}}',
          self._container_id]).strip()
      if status == 'running':
        break
      elif status in ('dead', 'exited'):
        subprocess.call([
            'docker',
            'container',
            'logs',
            self._container_id])
        raise RuntimeError('SDK failed to start.')
      time.sleep(1)

  def stop_worker(self):
    if self._container_id:
      subprocess.call([
          'docker',
          'kill',
          self._container_id])


class WorkerHandlerManager(object):
  def __init__(self, environments, job_provision_info=None):
    self._environments = environments
    self._job_provision_info = job_provision_info
    self._cached_handlers = collections.defaultdict(list)
    self._state = FnApiRunner.StateServicer() # rename?
    self._grpc_server = None

  def get_worker_handlers(self, environment_id, num_workers):
    if environment_id is None:
      # Any environment will do, pick one arbitrarily.
      environment_id = next(iter(self._environments.keys()))
    environment = self._environments[environment_id]
    # assume it's using grpc if environment is not EMBEDDED_PYTHON.
    if environment.urn != python_urns.EMBEDDED_PYTHON and \
        self._grpc_server is None:
      self._grpc_server = GrpcServer(self._state, self._job_provision_info)
    worker_handler_list = self._cached_handlers[environment_id]
    if len(worker_handler_list) < num_workers:
      for _ in range(len(worker_handler_list), num_workers):
        worker_handler = WorkerHandler.create(
            environment, self._state, self._job_provision_info,
            self._grpc_server)
        self._cached_handlers[environment_id].append(worker_handler)
        worker_handler.start_worker()
    return self._cached_handlers[environment_id][:num_workers]

  def close_all(self):
    for worker_handler_list in self._cached_handlers.values():
      for worker_handler in set(worker_handler_list):
        try:
          worker_handler.close()
        except Exception:
          logging.error("Error closing worker_handler %s" % worker_handler,
                        exc_info=True)
    self._cached_handlers = {}
    if self._grpc_server is not None:
      self._grpc_server.close()
      self._grpc_server = None


class ExtendedProvisionInfo(object):
  def __init__(self, provision_info=None, artifact_staging_dir=None):
    self.provision_info = (
        provision_info or beam_provision_api_pb2.ProvisionInfo())
    self.artifact_staging_dir = artifact_staging_dir


_split_managers = []


@contextlib.contextmanager
def split_manager(stage_name, split_manager):
  """Registers a split manager to control the flow of elements to a given stage.

  Used for testing.

  A split manager should be a coroutine yielding desired split fractions,
  receiving the corresponding split results. Currently, only one input is
  supported.
  """
  try:
    _split_managers.append((stage_name, split_manager))
    yield
  finally:
    _split_managers.pop()


class BundleManager(object):
  """Manages the execution of a bundle from the runner-side.

  This class receives a bundle descriptor, and performs the following tasks:
  - Registration of the bundle with the worker.
  - Splitting of the bundle
  - Setting up any other bundle requirements (e.g. side inputs).
  - Submitting the bundle to worker for execution
  - Passing bundle input data to the worker
  - Collecting bundle output data from the worker
  - Finalizing the bundle.
  """

  _uid_counter = 0
  _lock = threading.Lock()

  def __init__(
      self, worker_handler_list, get_buffer, get_input_coder_impl,
      bundle_descriptor, progress_frequency=None, skip_registration=False):
    """Set up a bundle manager.

    Args:
      worker_handler_list
      get_buffer (Callable[[str], list])
      get_input_coder_impl (Callable[[str], Coder])
      bundle_descriptor (beam_fn_api_pb2.ProcessBundleDescriptor)
      progress_frequency
      skip_registration
    """
    self._worker_handler_list = worker_handler_list
    self._get_buffer = get_buffer
    self._get_input_coder_impl = get_input_coder_impl
    self._bundle_descriptor = bundle_descriptor
    self._registered = skip_registration
    self._progress_frequency = progress_frequency
    self._worker_handler = None

  def _send_input_to_worker(self,
                            process_bundle_id,
                            read_transform_id,
                            byte_streams):
    data_out = self._worker_handler.data_conn.output_stream(
        process_bundle_id, read_transform_id)
    for byte_stream in byte_streams:
      data_out.write(byte_stream)
    data_out.close()

  def _register_bundle_descriptor(self):
    if self._registered:
      registration_future = None
    else:
      process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[self._bundle_descriptor]))
      registration_future = self._worker_handler.control_conn.push(
          process_bundle_registration)
      self._registered = True

    return registration_future

  def _select_split_manager(self):
    """TODO(pabloem) WHAT DOES THIS DO"""
    unique_names = set(
        t.unique_name for t in self._bundle_descriptor.transforms.values())
    for stage_name, candidate in reversed(_split_managers):
      if (stage_name in unique_names
          or (stage_name + '/Process') in unique_names):
        split_manager = candidate
        break
    else:
      split_manager = None

    return split_manager

  def _generate_splits_for_testing(self,
                                   split_manager,
                                   inputs,
                                   process_bundle_id):
    split_results = []
    read_transform_id, buffer_data = only_element(inputs.items())

    byte_stream = b''.join(buffer_data)
    num_elements = len(list(
        self._get_input_coder_impl(read_transform_id).decode_all(byte_stream)))

    # Start the split manager in case it wants to set any breakpoints.
    split_manager_generator = split_manager(num_elements)
    try:
      split_fraction = next(split_manager_generator)
      done = False
    except StopIteration:
      done = True

    # Send all the data.
    self._send_input_to_worker(
        process_bundle_id, read_transform_id, [byte_stream])

    # Execute the requested splits.
    while not done:
      if split_fraction is None:
        split_result = None
      else:
        split_request = beam_fn_api_pb2.InstructionRequest(
            process_bundle_split=
            beam_fn_api_pb2.ProcessBundleSplitRequest(
                instruction_reference=process_bundle_id,
                desired_splits={
                    read_transform_id:
                    beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
                        fraction_of_remainder=split_fraction,
                        estimated_input_elements=num_elements)
                }))
        split_response = self._worker_handler.control_conn.push(
            split_request).get()
        for t in (0.05, 0.1, 0.2):
          waiting = ('Instruction not running', 'not yet scheduled')
          if any(msg in split_response.error for msg in waiting):
            time.sleep(t)
            split_response = self._worker_handler.control_conn.push(
                split_request).get()
        if 'Unknown process bundle' in split_response.error:
          # It may have finished too fast.
          split_result = None
        elif split_response.error:
          raise RuntimeError(split_response.error)
        else:
          split_result = split_response.process_bundle_split
          split_results.append(split_result)
      try:
        split_fraction = split_manager_generator.send(split_result)
      except StopIteration:
        break
    return split_results

  def process_bundle(self, inputs, expected_outputs):
    # Unique id for the instruction processing this bundle.
    with BundleManager._lock:
      BundleManager._uid_counter += 1
      process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
      self._worker_handler = self._worker_handler_list[
          BundleManager._uid_counter % len(self._worker_handler_list)]

    # Register the bundle descriptor, if needed - noop if already registered.
    registration_future = self._register_bundle_descriptor()
    # Check that the bundle was successfully registered.
    if registration_future and registration_future.get().error:
      raise RuntimeError(registration_future.get().error)

    split_manager = self._select_split_manager()
    if not split_manager:
      # If there is no split_manager, write all input data to the channel.
      for transform_id, elements in inputs.items():
        self._send_input_to_worker(
            process_bundle_id, transform_id, elements)

    # Actually start the bundle.
    process_bundle_req = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=self._bundle_descriptor.id))
    result_future = self._worker_handler.control_conn.push(process_bundle_req)

    split_results = []
    with ProgressRequester(
        self._worker_handler, process_bundle_id, self._progress_frequency):

      if split_manager:
        split_results = self._generate_splits_for_testing(
            split_manager, inputs, process_bundle_id)

      # Gather all output data.
      for output in self._worker_handler.data_conn.input_elements(
          process_bundle_id,
          expected_outputs.keys(),
          abort_callback=lambda: (result_future.is_done()
                                  and result_future.get().error)):
        if output.ptransform_id in expected_outputs:
          with BundleManager._lock:
            self._get_buffer(
                expected_outputs[output.ptransform_id]).append(output.data)

      logging.debug('Wait for the bundle %s to finish.' % process_bundle_id)
      result = result_future.get()

    if result.error:
      raise RuntimeError(result.error)

    if result.process_bundle.requires_finalization:
      finalize_request = beam_fn_api_pb2.InstructionRequest(
          finalize_bundle=
          beam_fn_api_pb2.FinalizeBundleRequest(
              instruction_reference=process_bundle_id
          ))
      self._worker_handler.control_conn.push(finalize_request)

    return result, split_results


class ParallelBundleManager(BundleManager):

  def __init__(
      self, worker_handler_list, get_buffer, get_input_coder_impl,
      bundle_descriptor, progress_frequency=None, skip_registration=False,
      **kwargs):
    super(ParallelBundleManager, self).__init__(
        worker_handler_list, get_buffer, get_input_coder_impl,
        bundle_descriptor, progress_frequency, skip_registration)
    self._num_workers = kwargs.pop('num_workers', 1)

  def process_bundle(self, inputs, expected_outputs):
    part_inputs = [{} for _ in range(self._num_workers)]
    for name, input in inputs.items():
      for ix, part in enumerate(input.partition(self._num_workers)):
        part_inputs[ix][name] = part

    merged_result = None
    split_result_list = []
    with futures.ThreadPoolExecutor(max_workers=self._num_workers) as executor:
      for result, split_result in executor.map(lambda part: BundleManager(
          self._worker_handler_list, self._get_buffer,
          self._get_input_coder_impl, self._bundle_descriptor,
          self._progress_frequency, self._registered).process_bundle(
              part, expected_outputs), part_inputs):

        split_result_list += split_result
        if merged_result is None:
          merged_result = result
        else:
          merged_result = beam_fn_api_pb2.InstructionResponse(
              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                  monitoring_infos=monitoring_infos.consolidate(
                      itertools.chain(
                          result.process_bundle.monitoring_infos,
                          merged_result.process_bundle.monitoring_infos))),
              error=result.error or merged_result.error)

    return merged_result, split_result_list


class ProgressRequester(threading.Thread):
  """ Thread that asks SDK Worker for progress reports with a certain frequency.

  A callback can be passed to call with progress updates.
  """

  def __init__(self, worker_handler, instruction_id, frequency, callback=None):
    super(ProgressRequester, self).__init__()
    self._worker_handler = worker_handler
    self._instruction_id = instruction_id
    self._frequency = frequency
    self._done = False
    self._latest_progress = None
    self._callback = callback
    self.daemon = True

  def __enter__(self):
    if self._frequency:
      self.start()

  def __exit__(self, *unused_exc_info):
    if self._frequency:
      self.stop()

  def run(self):
    while not self._done:
      try:
        progress_result = self._worker_handler.control_conn.push(
            beam_fn_api_pb2.InstructionRequest(
                process_bundle_progress=
                beam_fn_api_pb2.ProcessBundleProgressRequest(
                    instruction_reference=self._instruction_id))).get()
        self._latest_progress = progress_result.process_bundle_progress
        if self._callback:
          self._callback(self._latest_progress)
      except Exception as exn:
        logging.error("Bad progress: %s", exn)
      time.sleep(self._frequency)

  def stop(self):
    self._done = True


class ControlFuture(object):
  def __init__(self, instruction_id, response=None):
    self.instruction_id = instruction_id
    if response:
      self._response = response
    else:
      self._response = None
      self._condition = threading.Condition()

  def is_done(self):
    return self._response is not None

  def set(self, response):
    with self._condition:
      self._response = response
      self._condition.notify_all()

  def get(self, timeout=None):
    if not self._response:
      with self._condition:
        if not self._response:
          self._condition.wait(timeout)
    return self._response


class FnApiMetrics(metrics.metric.MetricResults):
  def __init__(self, step_monitoring_infos, user_metrics_only=True):
    """Used for querying metrics from the PipelineResult object.

      step_monitoring_infos: Per step metrics specified as MonitoringInfos.
      use_monitoring_infos: If true, return the metrics based on the
          step_monitoring_infos.
    """
    self._counters = {}
    self._distributions = {}
    self._gauges = {}
    self._user_metrics_only = user_metrics_only
    self._init_metrics_from_monitoring_infos(step_monitoring_infos)
    self._monitoring_infos = step_monitoring_infos

  def _init_metrics_from_monitoring_infos(self, step_monitoring_infos):
    for smi in step_monitoring_infos.values():
      # Only include user metrics.
      for mi in smi:
        if (self._user_metrics_only and
            not monitoring_infos.is_user_monitoring_info(mi)):
          continue
        key = self._to_metric_key(mi)
        if monitoring_infos.is_counter(mi):
          self._counters[key] = (
              monitoring_infos.extract_metric_result_map_value(mi))
        elif monitoring_infos.is_distribution(mi):
          self._distributions[key] = (
              monitoring_infos.extract_metric_result_map_value(mi))
        elif monitoring_infos.is_gauge(mi):
          self._gauges[key] = (
              monitoring_infos.extract_metric_result_map_value(mi))

  def _to_metric_key(self, monitoring_info):
    # Right now this assumes that all metrics have a PTRANSFORM
    ptransform_id = monitoring_info.labels['PTRANSFORM']
    namespace, name = monitoring_infos.parse_namespace_and_name(monitoring_info)
    return MetricKey(ptransform_id, MetricName(namespace, name))

  def query(self, filter=None):
    counters = [metrics.execution.MetricResult(k, v, v)
                for k, v in self._counters.items()
                if self.matches(filter, k)]
    distributions = [metrics.execution.MetricResult(k, v, v)
                     for k, v in self._distributions.items()
                     if self.matches(filter, k)]
    gauges = [metrics.execution.MetricResult(k, v, v)
              for k, v in self._gauges.items()
              if self.matches(filter, k)]

    return {self.COUNTERS: counters,
            self.DISTRIBUTIONS: distributions,
            self.GAUGES: gauges}

  def monitoring_infos(self):
    return [item for sublist in self._monitoring_infos.values() for item in
            sublist]


class RunnerResult(runner.PipelineResult):
  def __init__(self, state, monitoring_infos_by_stage, metrics_by_stage):
    super(RunnerResult, self).__init__(state)
    self._monitoring_infos_by_stage = monitoring_infos_by_stage
    self._metrics_by_stage = metrics_by_stage
    self._metrics = None
    self._monitoring_metrics = None

  def wait_until_finish(self, duration=None):
    return self._state

  def metrics(self):
    """Returns a queryable object including user metrics only."""
    if self._metrics is None:
      self._metrics = FnApiMetrics(
          self._monitoring_infos_by_stage, user_metrics_only=True)
    return self._metrics

  def monitoring_metrics(self):
    """Returns a queryable object including all metrics."""
    if self._monitoring_metrics is None:
      self._monitoring_metrics = FnApiMetrics(
          self._monitoring_infos_by_stage, user_metrics_only=False)
    return self._monitoring_metrics
